{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68e0e186-3c69-47e4-9f6a-86e4c0eb72d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "\n",
    "from unet import UNetModel\n",
    "from diffusion import GaussianDiffusion\n",
    "\n",
    "import tqdm\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib as mpl\n",
    "mpl.rc('image', cmap='gray')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8800b4d1-3c1f-4f8e-a514-2a1a0de7f73e",
   "metadata": {},
   "source": [
    "## Load Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "085ef277-6ee3-436e-93f6-27c6d7522108",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load MNIST dataset\n",
    "# Use the first GPU when one is available; otherwise fall back to the CPU so\n",
    "# the notebook still runs on machines without CUDA.\n",
    "device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
    "batch_size = 128\n",
    "\n",
    "transforms = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.ToTensor(),\n",
    "    torchvision.transforms.Pad(2), # Make MNIST images 32x32\n",
    "    torchvision.transforms.Normalize(0.5, 0.5), # (x - 0.5) / 0.5\n",
    "])\n",
    "mnist_train = torchvision.datasets.MNIST(root='data/', train=True, transform=transforms, download=True)\n",
    "data_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=4)\n",
    "\n",
    "# Pull a single batch to preview a few digits with their labels\n",
    "img, labels = next(iter(data_loader))\n",
    "\n",
    "fig, ax = plt.subplots(1, 4, figsize=(15,15))\n",
    "for i in range(4):\n",
    "    ax[i].imshow(img[i,0,:,:].numpy())\n",
    "    ax[i].set_title(str(labels[i].item()))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76bfb6c0-8dc8-414c-88c3-bd7fcb7919f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample from diffusion process\n",
    "diffusion = GaussianDiffusion(T=1000, schedule='linear')\n",
    "\n",
    "x0, y = next(iter(data_loader))\n",
    "# One timestep per image, drawn uniformly from 1..T inclusive. Size it from the\n",
    "# batch actually returned (as the training loop does) rather than from the\n",
    "# configured batch_size, so a shorter batch cannot cause a length mismatch.\n",
    "t = np.random.randint(1, diffusion.T+1, x0.shape[0]).astype(int)\n",
    "# The second return value of sample() is not needed for this preview\n",
    "xt, _ = diffusion.sample(x0, t)\n",
    "\n",
    "# Top row: clean images with their labels; bottom row: the same images after\n",
    "# diffusion.sample, titled with the timestep drawn for each\n",
    "fig, ax = plt.subplots(2, 4, figsize=(15,8))\n",
    "for i in range(4):\n",
    "    ax[0,i].imshow(x0[i,0,:,:].numpy(), vmin=-1, vmax=1)\n",
    "    ax[0,i].set_title(str(y[i].item()))\n",
    "\n",
    "    ax[1,i].imshow(xt[i,0,:,:].numpy(), vmin=-1, vmax=1)\n",
    "    ax[1,i].set_title(f't={t[i]}')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6abefcd1-3d19-472d-9183-a7864179fd15",
   "metadata": {},
   "source": [
    "## Train Diffusion Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6126fd7a-1b3c-48d4-a951-a9ef0a11a320",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train network\n",
    "net = UNetModel(image_size=32, in_channels=1, out_channels=1, \n",
    "                model_channels=64, num_res_blocks=2, channel_mult=(1,2,3,4),\n",
    "                attention_resolutions=[8,4], num_heads=4).to(device)\n",
    "net.train()\n",
    "print('Network parameters:', sum(p.numel() for p in net.parameters()))\n",
    "\n",
    "opt = torch.optim.Adam(net.parameters(), lr=1e-4)\n",
    "\n",
    "diffusion = GaussianDiffusion(T=1000, schedule='linear')\n",
    "\n",
    "epochs = 10\n",
    "update_every = 20 # Refresh the progress-bar loss every this many batches\n",
    "for e in range(epochs):\n",
    "    print(f'Epoch [{e+1}/{epochs}]')\n",
    "\n",
    "    losses = [] # Losses accumulated since the last progress-bar refresh\n",
    "    batch_bar = tqdm.tqdm(data_loader)\n",
    "    for i, (img, labels) in enumerate(batch_bar):\n",
    "        n = img.shape[0]\n",
    "\n",
    "        # Sample from the diffusion process: one timestep in 1..T per image\n",
    "        t = np.random.randint(1, diffusion.T+1, n).astype(int)\n",
    "        xt, epsilon = diffusion.sample(img, t)\n",
    "        t = torch.from_numpy(t).float().view(n)\n",
    "\n",
    "        # Pass through network\n",
    "        out = net(xt.float().to(device), t.to(device))\n",
    "\n",
    "        # Compute loss and backprop: regress the network output onto epsilon\n",
    "        loss = F.mse_loss(out, epsilon.float().to(device))\n",
    "        opt.zero_grad()\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "\n",
    "        losses.append(loss.item())\n",
    "        if i % update_every == 0:\n",
    "            batch_bar.set_postfix({'Loss': np.mean(losses)})\n",
    "            losses = []\n",
    "\n",
    "    # Report the batches seen since the last refresh. The list is empty when the\n",
    "    # final batch index is a multiple of update_every; np.mean([]) would then\n",
    "    # emit a RuntimeWarning and show nan, so skip the update in that case.\n",
    "    if losses:\n",
    "        batch_bar.set_postfix({'Loss': np.mean(losses)})\n",
    "\n",
    "    # Visualize sample: generate one image in eval mode without tracking\n",
    "    # gradients, then switch back to train mode for the next epoch\n",
    "    with torch.no_grad():\n",
    "        net.eval()\n",
    "        x = diffusion.inverse(net, shape=(1,32,32), device=device)\n",
    "        net.train()\n",
    "\n",
    "    plt.figure(figsize=(5,5))\n",
    "    plt.imshow(x.cpu().numpy()[0,0,:,:], vmin=-1, vmax=1)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e386361b-4ff5-436c-afd0-44826c646768",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw one image from the learned diffusion process\n",
    "diffusion = GaussianDiffusion(T=1000, schedule='linear')\n",
    "\n",
    "# Sample in eval mode with gradient tracking disabled, then restore train mode\n",
    "net.eval()\n",
    "with torch.no_grad():\n",
    "    x = diffusion.inverse(net, shape=(1,32,32), device=device)\n",
    "net.train()\n",
    "\n",
    "# Display the first channel of the first generated image\n",
    "fig, ax = plt.subplots(figsize=(5,5))\n",
    "ax.imshow(x.cpu().numpy()[0, 0], vmin=-1, vmax=1)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "17441f20-fed2-4100-9f4e-4b3d32f3dbdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save/Load model\n",
    "#torch.save(net.state_dict(), 'models/mnist_unet.pth')\n",
    "#print('Saved model')\n",
    "\n",
    "# Rebuild the network with the same hyperparameters as the training cell, then\n",
    "# restore its weights from disk\n",
    "net = UNetModel(image_size=32, in_channels=1, out_channels=1, \n",
    "                model_channels=64, num_res_blocks=2, channel_mult=(1,2,3,4),\n",
    "                attention_resolutions=[8,4], num_heads=4).to(device)\n",
    "# map_location remaps the stored tensors onto `device`, so a checkpoint saved\n",
    "# from a GPU also loads on a machine without CUDA\n",
    "net.load_state_dict(torch.load('models/mnist_unet.pth', map_location=device))\n",
    "net.train()\n",
    "print('Loaded model')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "torch_geo_simple",
   "language": "python",
   "name": "conda-env-torch_geo_simple-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
